(*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 *)

(* ModelVerifier: implements additional validation for model files. *)

open Core
open Ast
open Expression
module AccessPath = Analysis.TaintAccessPath
module PyrePysaApi = Interprocedural.PyrePysaApi

(* Summary of a callable's parameter list, built by [create_parameters_requirements] and used to
   check the parameters of a model against that callable. *)
type parameter_requirements = {
  (* Positions of the callable's positional-only parameters. *)
  anonymous_parameters_positions: Int.Set.t;
  (* Names of the callable's parameters: positional-only ones that have a name, plus named and
     keyword-only ones (the latter two sanitized with [Identifier.sanitized]). *)
  parameter_set: String.Set.t;
  (* Whether the callable has a variadic positional parameter (`*args`). *)
  has_star_parameter: bool;
  (* Whether the callable has a variadic keyword parameter (`**kwargs`). *)
  has_star_star_parameter: bool;
}

(* Folds a callable's parameter list into a [parameter_requirements] summary:
   - positional-only parameters contribute their position and, when named, their raw name;
   - named and keyword-only parameters contribute their sanitized name;
   - `*args` / `**kwargs` set the corresponding flag. *)
let create_parameters_requirements parameters =
  let initial_requirements =
    {
      anonymous_parameters_positions = Int.Set.empty;
      parameter_set = String.Set.empty;
      has_star_parameter = false;
      has_star_star_parameter = false;
    }
  in
  let add_parameter requirements parameter =
    let open PyrePysaApi.ModelQueries.FunctionParameter in
    match parameter with
    | Variable _ -> { requirements with has_star_parameter = true }
    | Keywords _ -> { requirements with has_star_star_parameter = true }
    | Named { name; _ }
    | KeywordOnly { name; _ } ->
        let sanitized_name = Identifier.sanitized name in
        { requirements with parameter_set = Set.add requirements.parameter_set sanitized_name }
    | PositionalOnly { position; name; _ } ->
        let parameter_set =
          match name with
          | None -> requirements.parameter_set
          | Some parameter_name -> Set.add requirements.parameter_set parameter_name
        in
        let anonymous_parameters_positions =
          Set.add requirements.anonymous_parameters_positions position
        in
        { requirements with anonymous_parameters_positions; parameter_set }
  in
  List.fold parameters ~init:initial_requirements ~f:add_parameter


(* Packs an error [kind] together with the model file [path] and [location] it was found at. *)
let model_verification_error ~path ~location kind =
  { ModelVerificationError.location; path; kind }

(* Rejects model parameters that carry a default value other than `...`; such defaults are a
   common mistake when writing models. Only the first offending parameter is reported. *)
let verify_model_syntax ~path ~location ~callable_name ~normalized_model_parameters =
  let invalid_default_value_error { AccessPath.NormalizedParameter.original; _ } =
    match Node.value original with
    | { Parameter.value = None; _ } -> None
    | { Parameter.value = Some expression; name; _ } -> (
        match Node.value expression with
        | Expression.Constant Constant.Ellipsis -> None
        | _ ->
            Some
              (model_verification_error
                 ~path
                 ~location
                 (InvalidDefaultValue
                    { callable_name = Reference.show callable_name; name; expression })))
  in
  match List.find_map normalized_model_parameters ~f:invalid_default_value_error with
  | None -> Ok ()
  | Some error -> Error error


(* Fails with [ImportedFunctionModel] when [imported_name] is present and differs from
   [callable_name]; succeeds otherwise. *)
let verify_imported_model ~path ~location ~callable_name ~imported_name =
  match imported_name with
  | None -> Ok ()
  | Some actual_name when Reference.equal callable_name actual_name -> Ok ()
  | Some actual_name ->
      Error
        (model_verification_error
           ~path
           ~location
           (ImportedFunctionModel { name = callable_name; actual_name }))


(* Checks the model's normalized parameters against a single [callable_signature] and returns one
   error record (a reason, plus [callable_signature] as the overload when [add_overload_in_error]
   is true) per incompatibility. Errors are accumulated by consing, so they come out in reverse
   order of the model parameters. Signatures whose parameters are not an explicit
   [FunctionParameters.List] are not checked and yield no errors.

   Named requirements are consumed as they are matched: once a model parameter matches a name,
   that name is removed from [parameter_set], so a second model parameter with the same name is
   only accepted when `**kwargs` (or, for a positional parameter, `*args`) covers it. Requirements
   still unsatisfied at the end are discarded, i.e. callable parameters that the model omits are
   not reported.

   @raise Failure if a model parameter root is [LocalResult], [Variable] or [CapturedVariable];
   parameter normalization is not expected to produce these. *)
let model_compatible_errors
    ~callable_signature:
      ({ PyrePysaApi.ModelQueries.FunctionSignature.parameters = callable_parameters; _ } as
      callable_signature)
    ~add_overload_in_error
    ~normalized_model_parameters
  =
  let open ModelVerificationError in
  (* [position] is the index of the model parameter in [normalized_model_parameters]. *)
  let validate_model_parameter
      position
      (errors, requirements)
      { AccessPath.NormalizedParameter.root = model_parameter; _ }
    =
    let open AccessPath.Root in
    match model_parameter with
    | LocalResult
    | Variable _
    | CapturedVariable _ ->
        failwith
          "LocalResult|Variable|CapturedVariable won't be generated by \
           AccessPath.Root.normalize_parameters, and they cannot be compared with type_parameters."
    | PositionalParameter { name; positional_only = true; _ } ->
        if Core.Set.mem requirements.anonymous_parameters_positions position then
          errors, requirements
        else if requirements.has_star_parameter then
          (* Not a positional-only slot of the callable, but `*args` may still absorb it. *)
          errors, requirements
        else
          ( IncompatibleModelError.UnexpectedPositionalOnlyParameter
              {
                name;
                position;
                valid_positions = Core.Set.elements requirements.anonymous_parameters_positions;
              }
            :: errors,
            requirements )
    | PositionalParameter { name; positional_only = false; _ }
    | NamedParameter { name } ->
        let name = Identifier.sanitized name in
        let { parameter_set; has_star_parameter; has_star_star_parameter; _ } = requirements in
        if Core.Set.mem parameter_set name then
          (* Consume a required or optional named parameter. *)
          errors, { requirements with parameter_set = Core.Set.remove parameter_set name }
        else if has_star_star_parameter then
          (* The name is not a parameter of the callable, but `**kwargs` covers it. *)
          errors, requirements
        else if has_star_parameter then (
          (* `*args` can absorb a positional model parameter, but not a named-only one. *)
          match model_parameter with
          | PositionalParameter _ -> errors, requirements
          | _ -> IncompatibleModelError.UnexpectedNamedParameter name :: errors, requirements)
        else
          IncompatibleModelError.UnexpectedNamedParameter name :: errors, requirements
    | StarParameter _ ->
        if requirements.has_star_parameter then
          errors, requirements
        else
          IncompatibleModelError.UnexpectedStarredParameter :: errors, requirements
    | StarStarParameter _ ->
        if requirements.has_star_star_parameter then
          errors, requirements
        else
          IncompatibleModelError.UnexpectedDoubleStarredParameter :: errors, requirements
  in
  match callable_parameters with
  | PyrePysaApi.ModelQueries.FunctionParameters.List parameters ->
      let parameter_requirements = create_parameters_requirements parameters in
      (* Leftover requirements are intentionally ignored; see the header comment. *)
      let errors, _ =
        List.foldi
          normalized_model_parameters
          ~f:validate_model_parameter
          ~init:([], parameter_requirements)
      in
      List.map
        ~f:(fun reason ->
          {
            IncompatibleModelError.reason;
            overload = Option.some_if add_overload_in_error callable_signature;
          })
        errors
  | _ -> []


(* Verifies a callable model:
   1. default values must be absent or `...` ([verify_model_syntax]);
   2. the model must not target an imported name ([verify_imported_model]);
   3. when [callable_signatures] is provided, the model parameters must be compatible with at
      least one overload. If no overload matches, the errors of every overload are reported
      together in a single [IncompatibleModelError].
   The first failing step determines the returned error. *)
let verify_signature
    ~path
    ~location
    ~normalized_model_parameters
    ~name:callable_name
    ~imported_name
    callable_signatures
  =
  let open Result in
  verify_model_syntax ~path ~location ~callable_name ~normalized_model_parameters
  >>= fun () ->
  verify_imported_model ~path ~location ~callable_name ~imported_name
  >>= fun () ->
  match callable_signatures with
  | None -> Ok ()
  | Some callable_signatures -> (
      (* Only mention which overload failed when there is more than one. *)
      let add_overload_in_error = List.length callable_signatures > 1 in
      (* Every overload is checked eagerly, so any exception is raised exactly as before. *)
      let errors_in_overloads =
        List.map callable_signatures ~f:(fun callable_signature ->
            model_compatible_errors
              ~callable_signature
              ~add_overload_in_error
              ~normalized_model_parameters)
      in
      if List.exists errors_in_overloads ~f:List.is_empty then
        (* The model is compatible with at least one overload. *)
        Ok ()
      else
        match List.concat errors_in_overloads with
        | [] -> Ok ()
        | errors ->
            Error
              (model_verification_error
                 ~path
                 ~location
                 (IncompatibleModelError
                    { name = Reference.show callable_name; callable_signatures; errors })))


(* Verifies that [name] can be modeled as a global attribute.
   - Classes, modules and functions are rejected with a dedicated error.
   - Otherwise, if the prefix of [name] has class attributes, the last component of [name] must be
     one of them ([MissingAttribute] if not).
   - If the prefix has no class attributes but [name] resolved to some global, it is accepted.
   - If nothing resolved, report [MissingSymbol] when the first component of [name] (after adding
     the builtins prefix) resolves, and [BaseModuleNotInEnvironment] when it does not. *)
let verify_global_attribute ~path ~location ~pyre_api ~name =
  let module Global = PyrePysaApi.ModelQueries.Global in
  let error kind = Error (model_verification_error ~path ~location kind) in
  let resolve_global reference =
    PyrePysaApi.ModelQueries.resolve_qualified_name_to_global
      pyre_api
      ~is_property_getter:false
      ~is_property_setter:false
      reference
  in
  let global = resolve_global name in
  match global with
  | Some (Global.Class _) -> error (ModelingClassAsAttribute (Reference.show name))
  | Some Global.Module -> error (ModelingModuleAsAttribute (Reference.show name))
  | Some (Global.Function _) -> error (ModelingCallableAsAttribute (Reference.show name))
  | Some (Global.ModuleGlobal _)
  | Some (Global.ClassAttribute _)
  | Some (Global.UnknownClassAttribute _)
  | Some (Global.UnknownModuleGlobal _)
  | None -> (
      let class_name = Reference.prefix name |> Option.value ~default:Reference.empty in
      let class_attributes =
        PyrePysaApi.ReadOnly.get_class_attributes
          pyre_api
          ~include_generated_attributes:false
          ~only_simple_assignments:false
          (Reference.show class_name)
      in
      match class_attributes, global with
      | Some class_attributes, _ ->
          (* Computed once and reused in the error, instead of calling [Reference.last] twice. *)
          let attribute_name = Reference.last name in
          if List.mem ~equal:String.equal class_attributes attribute_name then
            Ok ()
          else
            error (MissingAttribute { class_name = Reference.show class_name; attribute_name })
      | None, Some _ -> Ok ()
      | None, None -> (
          let module_name =
            Reference.first (PyrePysaApi.ReadOnly.add_builtins_prefix pyre_api name)
          in
          match resolve_global (Reference.create module_name) with
          | Some _ -> error (MissingSymbol { module_name; symbol_name = Reference.show name })
          | None -> error (BaseModuleNotInEnvironment { module_name; name = Reference.show name })))


(* Names of stdlib modules. The set is deliberately partial: it only lists modules that we want to
   annotate with taint models. *)
let stdlib_modules =
  [
    "_socket";
    "argparse";
    "asyncio";
    "bz2";
    "code";
    "copy";
    "email";
    "gzip";
    "hmac";
    "http";
    "linecache";
    "logging";
    "marshal";
    "os";
    "pickle";
    "queue";
    "runpy";
    "shelve";
    "shlex";
    "shutil";
    "smtplib";
    "socket";
    "socketserver";
    "sqlite3";
    "tarfile";
    "tempfile";
    "urllib";
    "wsgiref";
    "xml";
  ]
  |> String.Set.of_list


(* Unlike Pyre, Pyrefly does not type check stdlib modules (from typeshed) unless a source file
   under the roots includes them, directly or transitively. Such modules are therefore not visible
   to Pysa either, and taint models targeting them fail with [BaseModuleNotInEnvironment]. These
   errors are harmless, so they are dropped for the modules listed in [stdlib_modules]. *)
let filter_unused_stdlib_modules_errors errors =
  let is_unused_stdlib_module_error { ModelVerificationError.kind; _ } =
    match kind with
    | ModelVerificationError.BaseModuleNotInEnvironment { module_name; _ } ->
        Set.mem stdlib_modules module_name
    | _ -> false
  in
  List.filter errors ~f:(fun error -> not (is_unused_stdlib_module_error error))
